ScatterElements

返回一个新tensor,根据指定索引和更新值对input中的元素进行指定操作(替换或相加)。不支持隐式类型转换。举例:一个三维输入tensor的返回为:

output[indices[i][j][k]][j][k] = updates[i][j][k] #if axis == 0, reduction == "none"
output[i][indices[i][j][k]][k] += updates[i][j][k] #if axis == 1, reduction == "add"
output[i][j][indices[i][j][k]] = updates[i][j][k] #if axis == 2, reduction == "none"
输入:
  • input - 输入数据的地址

  • indices - 指定索引。

  • updates - 更新值。

  • param - 算子计算所需参数的结构体。其各成员见下述。

  • core_mask - 核掩码。

ScatterElementsParameter定义:

 1typedef struct ScatterElementsParameter {
 2    int* indices_stride_; // 对应于indices数组每一维度的步长
 3    int* output_stride_; // 对应于output数组每一维度的步长
 4    int input_dims_; // 输入张量的维度数
 5    int axis_; // 指定索引所在的轴
 6    int input_axis_size_; // 索引所在轴的元素数
 7    int indices_total_num_; // indices数组的总元素数
 8    int input_total_num_; // input数组的总元素数
 9    int reduction_type_; // 规约类型,0代表none,1代表add
10} ScatterElementsParameter;
输出:
  • output - 输出地址。

支持平台:

FT78NE MT7004

备注:

  • FT78NE 支持int8, int16, int32, fp32, fp64, cplx64, cplx128

  • MT7004 支持fp16, fp32, int16, int32, cplx64

  • 如果 indices 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。

  • 如果 indices 的值超出 input 索引上下界,则相应的 updates 会被忽略,不会写入 output,也不会抛出索引错误。

共享存储版本:

void i8_scatter_elements_s(int8_t *input, int8_t *output, int *indices, int8_t *updates, ScatterElementsParameter *param, int core_mask)
void i16_scatter_elements_s(int16_t *input, int16_t *output, int *indices, int16_t *updates, ScatterElementsParameter *param, int core_mask)
void i32_scatter_elements_s(int *input, int *output, int *indices, int *updates, ScatterElementsParameter *param, int core_mask)
void hp_scatter_elements_s(half *input, half *output, int *indices, half *updates, ScatterElementsParameter *param, int core_mask)
void fp_scatter_elements_s(float *input, float *output, int *indices, float *updates, ScatterElementsParameter *param, int core_mask)
void dp_scatter_elements_s(double *input, double *output, int *indices, double *updates, ScatterElementsParameter *param, int core_mask)
void c64_scatter_elements_s(float *input, float *output, int *indices, float *updates, ScatterElementsParameter *param, int core_mask)
void c128_scatter_elements_s(double *input, double *output, int *indices, double *updates, ScatterElementsParameter *param, int core_mask)

C调用示例:

 1void PackParam(ScatterElementsParameter* param, int* indices_shape, int* input_shape) {
 2    param->indices_stride_[param->input_dims_ - 1] = 1;
 3    int i;
 4    for (i = param->input_dims_ - 1; i > 0; --i) {
 5        param->indices_stride_[i - 1] = param->indices_stride_[i] * indices_shape[i];
 6    }
 7    param->output_stride_[param->input_dims_ - 1] = 1;
 8    for (i = param->input_dims_ - 1; i > 0; --i) {
 9        param->output_stride_[i - 1] = param->output_stride_[i] * input_shape[i];
10    }
11    param->indices_total_num_ = 1;
12    for (i = 0; i < param->input_dims_; i++) {
13        param->indices_total_num_ *= indices_shape[i];
14    }
15    param->input_total_num_ = 1;
16    for (i = 0; i < param->input_dims_; i++) {
17        param->input_total_num_ *= input_shape[i];
18    }
19    param->input_axis_size_ = input_shape[param->axis_];
20}
21
22void TestScatterElementsSMC(int* input_shape, int* indices_shape, int ndim, int axis, int reduction_type, int core_mask) {
23    int core_num = GetCoreNum(core_mask);
24    int core_id = get_core_id();
25    int logic_core_id = GetLogicCoreId(core_mask, core_id);
26    void* input_data = (void*)0x88000000;
27    void* output_data = (void*)0x98000000;
28    int* indices_data = (int*)0xA8000000;
29    void* updates_data = (void*)0xB8000000;
30    ScatterElementsParameter* param = (ScatterElementsParameter*)0xC8000000;
31    if (logic_core_id == 0) {
32        param->axis_ = axis;
33        param->input_dims_ = ndim;
34        param->indices_stride_ = (int*)0xC8020000;
35        param->output_stride_ = (int*)0xC8040000;
36        param->reduction_type_ = reduction_type;
37        PackParam(param, indices_shape, input_shape);
38    }
39    sys_bar(0, core_num); // 初始化参数完成后进行同步
40    fp_scatter_elements_s(input_data, output_data, indices_data, updates_data, param, core_mask);
41}
42
/*
 * Entry point of the shared-storage example: an 8x30 input with a 3x3
 * indices tensor, axis 0, no reduction.
 */
void main() {
    int input_shape[2] = {8, 30};
    int indices_shape[2] = {3, 3};
    int ndim = 2;            // must match the length of the shape arrays above
    int axis = 0;
    int reduction_type = 0;  // 0 = none
    int core_mask = 0b1111;  // low four bits set
    TestScatterElementsSMC(input_shape, indices_shape, ndim, axis, reduction_type, core_mask);
}

私有存储版本:

void i8_scatter_elements_p(int8_t *input, int8_t *output, int *indices, int8_t *updates, ScatterElementsParameter *param, int core_mask)
void i16_scatter_elements_p(int16_t *input, int16_t *output, int *indices, int16_t *updates, ScatterElementsParameter *param, int core_mask)
void i32_scatter_elements_p(int *input, int *output, int *indices, int *updates, ScatterElementsParameter *param, int core_mask)
void hp_scatter_elements_p(half *input, half *output, int *indices, half *updates, ScatterElementsParameter *param, int core_mask)
void fp_scatter_elements_p(float *input, float *output, int *indices, float *updates, ScatterElementsParameter *param, int core_mask)
void dp_scatter_elements_p(double *input, double *output, int *indices, double *updates, ScatterElementsParameter *param, int core_mask)
void c64_scatter_elements_p(float *input, float *output, int *indices, float *updates, ScatterElementsParameter *param, int core_mask)
void c128_scatter_elements_p(double *input, double *output, int *indices, double *updates, ScatterElementsParameter *param, int core_mask)

C调用示例:

 1void PackParam(ScatterElementsParameter* param, int* indices_shape, int* input_shape) {
 2    param->indices_stride_[param->input_dims_ - 1] = 1;
 3    int i;
 4    for (i = param->input_dims_ - 1; i > 0; --i) {
 5        param->indices_stride_[i - 1] = param->indices_stride_[i] * indices_shape[i];
 6    }
 7    param->output_stride_[param->input_dims_ - 1] = 1;
 8    for (i = param->input_dims_ - 1; i > 0; --i) {
 9        param->output_stride_[i - 1] = param->output_stride_[i] * input_shape[i];
10    }
11    param->indices_total_num_ = 1;
12    for (i = 0; i < param->input_dims_; i++) {
13        param->indices_total_num_ *= indices_shape[i];
14    }
15    param->input_total_num_ = 1;
16    for (i = 0; i < param->input_dims_; i++) {
17        param->input_total_num_ *= input_shape[i];
18    }
19    param->input_axis_size_ = input_shape[param->axis_];
20}
21
22void TestScatterElementsL2(int* input_shape, int* indices_shape, int ndim, int axis, int reduction_type, int core_mask) {
23    void* input_data = (void*)0x10000000; // 私有存储版本地址设置在AM内
24    void* output_data = (void*)0x10001000;
25    int* indices_data = (int*)0x10002000;
26    void* updates_data = (void*)0x10003000;
27    ScatterElementsParameter* param = (ScatterElementsParameter*)0x10004000;
28    param->axis_ = axis;
29    param->input_dims_ = ndim;
30    param->indices_stride_ = (int*)0x10005000;
31    param->output_stride_ = (int*)0x10006000;
32    param->reduction_type_ = reduction_type;
33    PackParam(param, indices_shape, input_shape);
34    fp_scatter_elements_p(input_data, output_data, indices_data, updates_data, param, core_mask);
35}
36
/*
 * Entry point of the private-storage example: an 8x30 input with a 3x3
 * indices tensor, axis 0, no reduction.
 */
void main() {
    int input_shape[2] = {8, 30};
    int indices_shape[2] = {3, 3};
    int ndim = 2;            // must match the length of the shape arrays above
    int axis = 0;
    int reduction_type = 0;  // 0 = none
    int core_mask = 0b0001;  // the private-storage version may only be launched on a single core
    TestScatterElementsL2(input_shape, indices_shape, ndim, axis, reduction_type, core_mask);
}